{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload edited project modules automatically while iterating.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "\n",
    "# Restrict the visible GPUs before the project code is imported, so the\n",
    "# setting is already in place if anything touches CUDA at import time.\n",
    "os.environ['CUDA_VISIBLE_DEVICES']=\"1\"\n",
    "\n",
    "# NOTE(review): the star import supplies every project name used in the cells\n",
    "# below (Config, CaptionModel, get_dataloader, ...); explicit imports would be\n",
    "# clearer once the exports of main are confirmed.\n",
    "from main  import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Configuration\n",
    "opt = Config()\n",
    "opt.caption_data_path = 'caption.pth' # raw caption data\n",
    "opt.test_img = 'img/example.jpeg' # input image\n",
    "opt.use_gpu = False  # whether to use the GPU (not necessary)\n",
    "#opt.model_ckpt='caption_0914_1947' # pretrained model\n",
    "opt.img_feature_path = 'results.pth'\n",
    "\n",
    "# Data\n",
    "vis = Visualizer(env = opt.env)\n",
    "dataloader = get_dataloader(opt)\n",
    "_data = dataloader.dataset._data\n",
    "word2ix,ix2word = _data['word2ix'],_data['ix2word']\n",
    "\n",
    "# Model\n",
    "model = CaptionModel(opt,word2ix,ix2word)\n",
    "if opt.model_ckpt:\n",
    "    model.load(opt.model_ckpt)\n",
    "# get_optimizer() takes a single learning rate; passing opt.lr2 as well raised\n",
    "# TypeError: get_optimizer() takes 2 positional arguments but 3 were given\n",
    "optimizer = model.get_optimizer(opt.lr1)\n",
    "criterion = t.nn.CrossEntropyLoss()\n",
    "if opt.use_gpu:\n",
    "    model.cuda()\n",
    "    criterion.cuda()\n",
    "\n",
    "# Statistics: created here so the loop below does not rely on a later cell\n",
    "loss_meter = meter.AverageValueMeter()\n",
    "\n",
    "for epoch in range(opt.epoch):\n",
    "    loss_meter.reset()\n",
    "    for ii,(imgs, (captions, lengths),indexes)  in tqdm.tqdm(enumerate(dataloader)):\n",
    "        # Training step\n",
    "        optimizer.zero_grad()\n",
    "        if opt.use_gpu:\n",
    "            imgs = imgs.cuda()\n",
    "            captions = captions.cuda()\n",
    "        imgs = Variable(imgs)\n",
    "        captions = Variable(captions)\n",
    "        # The model input drops the last row of captions; the target is the\n",
    "        # data part of the packed captions.\n",
    "        input_captions = captions[:-1]\n",
    "        target_captions = pack_padded_sequence(captions,lengths)[0]\n",
    "        score,_ = model(imgs,input_captions,lengths)\n",
    "        loss = criterion(score,target_captions)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        loss_meter.add(loss.data[0])\n",
    "\n",
    "        # Visualization\n",
    "        if (ii+1)%opt.plot_every ==0:\n",
    "            # Drop into the debugger when the debug flag file exists\n",
    "            if os.path.exists(opt.debug_file):\n",
    "                ipdb.set_trace()\n",
    "\n",
    "            vis.plot('loss',loss_meter.value()[0])\n",
    "\n",
    "            # Show the original image together with its human-written caption\n",
    "            raw_img = _data['ix2id'][indexes[0]]\n",
    "            img_path=opt.img_path+raw_img\n",
    "            raw_img = Image.open(img_path).convert('RGB')\n",
    "            raw_img = tv.transforms.ToTensor()(raw_img)\n",
    "\n",
    "            raw_caption = captions.data[:,0]\n",
    "            # word_ix (not ii) so the batch counter is not visually shadowed\n",
    "            raw_caption = ''.join([_data['ix2word'][word_ix] for word_ix in raw_caption])\n",
    "            vis.text(raw_caption,u'raw_caption')\n",
    "            vis.img('raw',raw_img,caption=raw_caption)\n",
    "\n",
    "            # Show the captions generated by the network\n",
    "            results = model.generate(imgs.data[0])\n",
    "            vis.text('</br>'.join(results),u'caption')\n",
    "    # Save the model at the end of every epoch\n",
    "    model.save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the GPU flag currently stored on the config object.\n",
    "opt.use_gpu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Statistics: meter that accumulates the training loss\n",
    "loss_meter = meter.AverageValueMeter()\n",
    "\n",
    "# Loss function, and the optimizer built by the model from opt.lr1\n",
    "criterion = t.nn.CrossEntropyLoss()\n",
    "optimizer = model.get_optimizer(opt.lr1)\n",
    "\n",
    "# Move the model and the loss to the GPU only when the config asks for it\n",
    "if opt.use_gpu:\n",
    "    model.cuda()\n",
    "    criterion.cuda()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the raw caption images. Set the CAPTION_IMG_PATH\n",
    "# environment variable to override it; the default is the original location.\n",
    "# The training loop builds file paths with opt.img_path + name, so\n",
    "# os.path.join(..., '') guarantees the trailing separator.\n",
    "opt.img_path = os.path.join(os.environ.get('CAPTION_IMG_PATH', '/home/cy/caption_data/'), '')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "151it [01:45,  1.31it/s]Process Process-24:\n",
      "Process Process-22:\n",
      "Process Process-21:\n",
      "Process Process-23:\n",
      "Traceback (most recent call last):\n",
      "Traceback (most recent call last):\n",
      "  File \"/usr/lib/python3.5/multiprocessing/process.py\", line 249, in _bootstrap\n",
      "    self.run()\n",
      "  File \"/usr/lib/python3.5/multiprocessing/process.py\", line 249, in _bootstrap\n",
      "    self.run()\n",
      "  File \"/usr/lib/python3.5/multiprocessing/process.py\", line 93, in run\n",
      "    self._target(*self._args, **self._kwargs)\n",
      "  File \"/usr/lib/python3.5/multiprocessing/process.py\", line 93, in run\n",
      "    self._target(*self._args, **self._kwargs)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py\", line 34, in _worker_loop\n",
      "    r = index_queue.get()\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py\", line 34, in _worker_loop\n",
      "    r = index_queue.get()\n",
      "Traceback (most recent call last):\n",
      "Traceback (most recent call last):\n",
      "  File \"/usr/lib/python3.5/multiprocessing/queues.py\", line 343, in get\n",
      "    res = self._reader.recv_bytes()\n",
      "  File \"/usr/lib/python3.5/multiprocessing/queues.py\", line 342, in get\n",
      "    with self._rlock:\n",
      "  File \"/usr/lib/python3.5/multiprocessing/process.py\", line 249, in _bootstrap\n",
      "    self.run()\n",
      "  File \"/usr/lib/python3.5/multiprocessing/connection.py\", line 216, in recv_bytes\n",
      "    buf = self._recv_bytes(maxlength)\n",
      "  File \"/usr/lib/python3.5/multiprocessing/process.py\", line 249, in _bootstrap\n",
      "    self.run()\n",
      "  File \"/usr/lib/python3.5/multiprocessing/process.py\", line 93, in run\n",
      "    self._target(*self._args, **self._kwargs)\n",
      "  File \"/usr/lib/python3.5/multiprocessing/process.py\", line 93, in run\n",
      "    self._target(*self._args, **self._kwargs)\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py\", line 34, in _worker_loop\n",
      "    r = index_queue.get()\n",
      "  File \"/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py\", line 34, in _worker_loop\n",
      "    r = index_queue.get()\n",
      "  File \"/usr/lib/python3.5/multiprocessing/queues.py\", line 342, in get\n",
      "    with self._rlock:\n",
      "  File \"/usr/lib/python3.5/multiprocessing/connection.py\", line 407, in _recv_bytes\n",
      "    buf = self._recv(4)\n",
      "  File \"/usr/lib/python3.5/multiprocessing/synchronize.py\", line 96, in __enter__\n",
      "    return self._semlock.__enter__()\n",
      "  File \"/usr/lib/python3.5/multiprocessing/queues.py\", line 342, in get\n",
      "    with self._rlock:\n",
      "  File \"/usr/lib/python3.5/multiprocessing/connection.py\", line 379, in _recv\n",
      "    chunk = read(handle, remaining)\n",
      "KeyboardInterrupt\n",
      "  File \"/usr/lib/python3.5/multiprocessing/synchronize.py\", line 96, in __enter__\n",
      "    return self._semlock.__enter__()\n",
      "KeyboardInterrupt\n",
      "KeyboardInterrupt\n",
      "  File \"/usr/lib/python3.5/multiprocessing/synchronize.py\", line 96, in __enter__\n",
      "    return self._semlock.__enter__()\n",
      "KeyboardInterrupt\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-24-4d8a361dbb70>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     14\u001b[0m         \u001b[0mscore\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimgs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0minput_captions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlengths\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m         \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscore\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtarget_captions\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m         \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     17\u001b[0m         \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     18\u001b[0m         \u001b[0mloss_meter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/torch/autograd/variable.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, retain_variables)\u001b[0m\n\u001b[1;32m    154\u001b[0m                 \u001b[0mVariable\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    155\u001b[0m         \"\"\"\n\u001b[0;32m--> 156\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_variables\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    157\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    158\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(variables, grad_variables, retain_graph, create_graph, retain_variables)\u001b[0m\n\u001b[1;32m     96\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     97\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[0;32m---> 98\u001b[0;31m         variables, grad_variables, retain_graph)\n\u001b[0m\u001b[1;32m     99\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
